Conversation
Apply per-parameter-class LR/eps scaling in setup_optimizer when use_mup=True on the model config. Mirrors the get_mup_config_overrides call added to MCore's setup_model_and_optimizer in NVIDIA/Megatron-LM#3058. The μP config fields (use_mup, mup_base_hidden_size, mup_width_mult, etc.) are already present via MCoreTransformerConfig inheritance — no model config changes needed. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
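The merge this commit describes can be sketched in plain Python. `apply_mup_overrides`, the parameter-class names, and the `config_overrides` key below are all hypothetical; the real code calls MCore's `get_mup_config_overrides` and merges its result into the `OptimizerConfig`:

```python
import logging

G_LOGGER = logging.getLogger(__name__)


def apply_mup_overrides(optimizer_kwargs, model_cfg):
    """Merge per-parameter-class muP LR overrides into optimizer kwargs.

    Illustrative stand-in for MCore's get_mup_config_overrides; the real
    parameter-class names and the eps scaling rule may differ.
    """
    if not getattr(model_cfg, "use_mup", False):
        return optimizer_kwargs

    width_mult = model_cfg.hidden_size / model_cfg.mup_base_hidden_size
    # muP: matrix-like ("hidden") weights train at lr / width_mult, while
    # vector-like params (embeddings, norms, biases) keep the base lr.
    mup_overrides = {
        "hidden_weights": {"lr_mult": 1.0 / width_mult},
        "vector_params": {"lr_mult": 1.0},
    }

    merged = dict(optimizer_kwargs)
    existing = {k: dict(v) for k, v in merged.get("config_overrides", {}).items()}
    for group, override in mup_overrides.items():
        # Existing user-provided overrides win a merge conflict only if they
        # set keys the muP overrides do not touch.
        existing.setdefault(group, {}).update(override)
    merged["config_overrides"] = existing
    G_LOGGER.info("Applied muP optimizer overrides (width_mult=%.2f)", width_mult)
    return merged
```

With `hidden_size=2048` and `mup_base_hidden_size=1024`, the hidden-weight group would get an LR multiplier of 0.5 while vector-like parameters keep the base LR.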
- Add test_pretrain_with_mup to verify use_mup=True flows end-to-end through the DDP-wrapped model into setup_optimizer (width_mult=2.0) - Add training/conftest.py overriding ensure_test_data to skip GitHub asset download; training tests use MockGPTDatasetConfig and need no external data - Add INFO log in setup_optimizer when μP overrides are applied Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
📝 Walkthrough

Introduces μP (Maximal Update Parameterization) support to the optimizer setup by retrieving the model configuration, computing parameter-group learning-rate overrides when enabled, and merging them into the existing optimizer configuration. Includes comprehensive unit and functional tests.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 2
🧹 Nitpick comments (1)
src/megatron/bridge/training/optim.py (1)
15-30: Move logger declaration below imports and align global naming rule.
`logger` at Line 18 is declared between import blocks, and the global name does not match the configured global naming convention. Keep imports contiguous, then declare the module logger using the repository's global naming style.

♻️ Proposed refactor

```diff
 import logging
 from typing import Optional, Union

-logger = logging.getLogger(__name__)
-
 from megatron.core.optimizer import (
     MegatronOptimizer,
     OptimizerConfig,
     get_megatron_optimizer,
     get_mup_config_overrides,
 )
@@
 from megatron.bridge.training.config import (
     OptimizerConfigOverrideProvider,
     OptimizerConfigOverrideProviderContext,
     SchedulerConfig,
 )
+
+G_LOGGER = logging.getLogger(__name__)
```

As per coding guidelines: "Organize imports in order: future imports, standard library, third-party ... first-party ... separated by blank lines" and "Use upper snake_case and prefix 'G' for global variables".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/training/optim.py` around lines 15 - 30, Move the module logger declaration so all imports remain contiguous and rename it to follow the repo global naming convention (upper snake case with 'G' prefix): replace the current logger variable with G_LOGGER = logging.getLogger(__name__) declared immediately after the import block; update any local uses of logger in this module (e.g., any calls referencing logger) to use G_LOGGER instead; ensure only the import lines stay between the import section and the new G_LOGGER declaration so imports are not interrupted.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: b00040a0-1fe8-41d4-a6bb-d459fee452cc
📒 Files selected for processing (4)
src/megatron/bridge/training/optim.py
tests/functional_tests/training/conftest.py
tests/functional_tests/training/test_pretrain.py
tests/unit_tests/training/test_optim.py
```python
@pytest.mark.run_only_on("GPU")
def test_pretrain_with_mup(self, tmp_path):
    """
    Test end-to-end training with μP (Maximal Update Parameterization) enabled.

    Verifies that use_mup=True flows through the full training stack: the model
    config's mup_width_mult is computed by finalize(), get_model_config() on the
    DDP-wrapped model still returns use_mup=True, and setup_optimizer applies the
    per-parameter-class LR overrides without error.

    Uses mup_base_hidden_size=1024 with hidden_size=2048 (width_mult=2.0) so that
    the LR scaling is non-trivial and any failure to apply overrides would be visible.
    """
    initialize_distributed()
    shared_base_dir = broadcast_path(tmp_path)

    tensorboard_dir = os.path.join(shared_base_dir, "tensorboard")

    if torch.distributed.get_rank() == 0:
        os.makedirs(tensorboard_dir, exist_ok=True)

    torch.distributed.barrier()

    try:
        global_batch_size = 8
        micro_batch_size = 1
        seq_length = 512
        total_iters = 5

        model_cfg = Llama32ModelProvider1B(
            tensor_model_parallel_size=1,
            pipeline_model_parallel_size=1,
            context_parallel_size=1,
            sequence_parallel=False,
            attention_softmax_in_fp32=True,
            pipeline_dtype=torch.bfloat16,
            bf16=True,
            seq_length=seq_length,
            make_vocab_size_divisible_by=128,
            vocab_size=None,
            num_layers=1,
            use_mup=True,
            mup_base_hidden_size=1024,  # width_mult = 2048/1024 = 2.0
        )

        cfg = ConfigContainer(
            model=model_cfg,
            train=TrainingConfig(
                train_iters=total_iters,
                global_batch_size=global_batch_size,
                micro_batch_size=micro_batch_size,
                exit_signal_handler=True,
            ),
            validation=ValidationConfig(
                eval_interval=5,
                eval_iters=2,
            ),
            optimizer=OptimizerConfig(
                optimizer="adam",
                bf16=True,
                fp16=False,
                adam_beta1=0.9,
                adam_beta2=0.95,
                adam_eps=1e-8,
                use_distributed_optimizer=True,
                clip_grad=1.0,
                lr=3e-3,
                weight_decay=0.01,
                min_lr=1e-6,
            ),
            scheduler=SchedulerConfig(
                start_weight_decay=0.033,
                end_weight_decay=0.033,
                weight_decay_incr_style="constant",
                lr_decay_style="cosine",
                lr_warmup_iters=1,
                lr_warmup_init=0.0,
                lr_decay_iters=total_iters,
                override_opt_param_scheduler=True,
            ),
            ddp=DistributedDataParallelConfig(
                check_for_nan_in_grad=True,
                grad_reduce_in_fp32=True,
                overlap_grad_reduce=True,
                overlap_param_gather=True,
                average_in_collective=True,
                use_distributed_optimizer=True,
            ),
            dataset=MockGPTDatasetConfig(
                random_seed=1234,
                reset_attention_mask=False,
                reset_position_ids=False,
                eod_mask_loss=False,
                seq_length=seq_length,
                num_dataset_builder_threads=1,
                data_sharding=True,
                dataloader_type="single",
                num_workers=1,
            ),
            logger=LoggerConfig(
                log_interval=5,
                tensorboard_dir=tensorboard_dir,
            ),
            tokenizer=TokenizerConfig(
                tokenizer_type="NullTokenizer",
                vocab_size=10000,
            ),
            checkpoint=CheckpointConfig(
                save_interval=100,
                ckpt_format="torch_dist",
            ),
            rng=RNGConfig(seed=1234),
        )

        pretrain(cfg, forward_step)

    finally:
        clear_directories(tmp_path)
```
test_pretrain_with_mup needs an explicit μP assertion.
Right now this is a smoke test only. Even if μP overrides stop being applied, the test can still pass as long as training does not crash. Please add at least one explicit check that μP-specific logic executed (for example, assert the μP optimizer override signal/log appears, or assert an observable μP-derived optimizer-group property).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/functional_tests/training/test_pretrain.py` around lines 171 - 289, Add
an explicit μP assertion in test_pretrain_with_mup to verify μP logic executed:
after constructing model_cfg with use_mup=True and mup_base_hidden_size=1024 and
before/after calling pretrain(forward_step), inspect the optimizer or model
config returned by the training setup (e.g., from pretrain or the object created
by Llama32ModelProvider1B) and assert a μP-specific signal—such as an optimizer
param-group lr scaling or a flag indicating mup overrides were applied (look for
methods/names like setup_optimizer, get_model_config, use_mup, mup_width_mult,
or the optimizer param_groups returned by pretrain)—so the test fails if no μP
overrides were applied.
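One way to satisfy this review comment is to assert on the INFO log that setup_optimizer emits when overrides are applied (which a later commit in this PR does via caplog). A sketch of such a helper; the exact log wording is an assumption, so it matches on a substring:

```python
import logging


def _is_mup_record(record):
    """True if a log record looks like the muP optimizer-override message."""
    msg = record.getMessage().lower()
    return record.levelno >= logging.INFO and ("mup" in msg or "μp" in msg)


def assert_mup_overrides_applied(records):
    """Fail unless a muP override log line was emitted.

    `records` is any iterable of logging.LogRecord, e.g. pytest's
    caplog.records captured around the pretrain() call.
    """
    assert any(_is_mup_record(r) for r in records), (
        "expected a muP optimizer-override log line, found none"
    )
```

In the test this would be used as `with caplog.at_level(logging.INFO): pretrain(cfg, forward_step)` followed by `assert_mup_overrides_applied(caplog.records)`.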
```python
@patch("megatron.bridge.training.optim._get_scheduler")
@patch("megatron.bridge.training.optim.get_megatron_optimizer")
@patch("megatron.bridge.training.optim.get_model_config")
def test_mup_disabled_skips_overrides(self, mock_get_model_config, mock_get_optimizer, mock_get_scheduler):
```
Fix Ruff ARG002: unused mock_get_scheduler arguments.
mock_get_scheduler is patched but not referenced in these tests, which triggers lint warnings.
✅ Minimal lint-only fix

```diff
-    def test_mup_disabled_skips_overrides(self, mock_get_model_config, mock_get_optimizer, mock_get_scheduler):
+    def test_mup_disabled_skips_overrides(self, mock_get_model_config, mock_get_optimizer, _mock_get_scheduler):
@@
-    def test_mup_enabled_calls_overrides(self, mock_get_model_config, mock_get_optimizer, mock_get_scheduler):
+    def test_mup_enabled_calls_overrides(self, mock_get_model_config, mock_get_optimizer, _mock_get_scheduler):
@@
-    def test_mup_overrides_merged_with_existing(self, mock_get_model_config, mock_get_optimizer, mock_get_scheduler):
+    def test_mup_overrides_merged_with_existing(self, mock_get_model_config, mock_get_optimizer, _mock_get_scheduler):
@@
-    def test_mup_model_list_uses_first_chunk(self, mock_get_model_config, mock_get_optimizer, mock_get_scheduler):
+    def test_mup_model_list_uses_first_chunk(self, mock_get_model_config, mock_get_optimizer, _mock_get_scheduler):
```

Also applies to lines 70, 96, 133.
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 51-51: Unused method argument: mock_get_scheduler
(ARG002)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unit_tests/training/test_optim.py` at line 51, The patched fixture
argument mock_get_scheduler is unused in the test functions (e.g.,
test_mup_disabled_skips_overrides, and the other tests flagged) and triggers
Ruff ARG002; fix by removing mock_get_scheduler from the function signatures or
renaming it to a leading-underscore name (e.g., _mock_get_scheduler) in each
affected test to mark it as intentionally unused (update signatures for
test_mup_disabled_skips_overrides and the other flagged test functions).
```python
        OptimizerConfigOverrideProviderContext(scheduler_config, optimizer_config, model)
    )

    # Apply μP optimizer scaling if enabled on the model config
```
Is it possible to do this in ConfigContainer.validate()? All the overriding logic is better kept in the same place. If not possible, it's okay to leave it here.
Good point. I think this can't live in ConfigContainer.validate() because the μP optimizer overrides require the post-DDP-wrapped model — specifically, get_model_config() needs to unwrap the DDP/FSDP shell to reach the underlying TransformerConfig. At validate() time the model hasn't been constructed yet (the config container is built before setup_model() is called). By the time setup_optimizer() is called, the model is fully wrapped and get_model_config() can safely retrieve use_mup and mup_width_mult.
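The sequencing constraint this reply describes (the config is only reachable after the DDP/FSDP wrappers exist) can be illustrated with a small unwrap helper. `get_transformer_config` and the `.module` attribute chain below are illustrative, not MCore's actual `get_model_config` implementation:

```python
def get_transformer_config(model):
    """Return the underlying transformer config from a (possibly wrapped) model.

    Sketch of the unwrap-then-read step: the `.module` chain mirrors the
    common DDP/Float16Module wrapping convention; MCore's real helper
    handles more wrapper types.
    """
    chunk = model[0] if isinstance(model, list) else model  # first pipeline chunk
    while not hasattr(chunk, "config") and hasattr(chunk, "module"):
        chunk = chunk.module
    return chunk.config
```

At validate() time no such wrapped object exists yet, so there is nothing to unwrap; by setup_optimizer() time the chain resolves to the config carrying `use_mup` and `mup_width_mult`.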
- optim.py: move G_LOGGER declaration below all imports, rename logger → G_LOGGER - test_optim.py: rename mock_get_scheduler → _mock_get_scheduler (fix ARG002 lint) - test_pretrain.py: add explicit μP assertion via caplog to verify overrides applied Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- move import logging to top of test_pretrain.py - add blank line before G_LOGGER to satisfy isort (I001) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
What does this PR do ?
Add a one line overview of what this PR aims to accomplish.
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. A Nvidia developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open "Draft" PR.
Additional Information
Summary by CodeRabbit
New Features
Tests